Molecular Dynamics in Dex

import plot

This is more or less a port of JAX MD to Dex, written to see how molecular dynamics looks in Dex. For now the two implementations are structured very similarly, though the details differ.

Math

def truncate(x: Float) -> Float =
  case x < 0.0 of
    True -> -floor(-x)
    False -> floor x
-- A mod that matches np.mod and python mod.
def pmod(x: Float, y: Float) -> Float = x - floor(x / y) * y
-- A mod that matches np.fmod and c fmod.
def fmod(x: Float, y: Float) -> Float = x - truncate (x / y) * y
def Vec(dim|Ix) -> Type = dim=>Float
def sq_norm(r: Vec dim) -> Float given (dim|Ix) = sum $ for i. r[i] * r[i]
def norm(r: Vec dim) -> Float given (dim|Ix) = sqrt $ sq_norm r
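
The difference between the two mods only shows up for negative arguments. As a rough illustration, here is a Python sketch (Python rather than Dex, purely so it can be checked against the standard library; the function names mirror the Dex ones above and are otherwise an assumption of this sketch):

```python
import math

def pmod(x: float, y: float) -> float:
    """Floored modulo: the result takes the sign of y (Python's %, np.mod)."""
    return x - math.floor(x / y) * y

def fmod(x: float, y: float) -> float:
    """Truncated modulo: the result takes the sign of x (C's fmod, np.fmod)."""
    return x - math.trunc(x / y) * y

# A negative coordinate is mapped back into [0, y) only by the floored version.
print(pmod(-1.5, 4.0), fmod(-1.5, 4.0))
```

The floored version is the one we want for wrapping coordinates into a periodic box, since it always lands in [0, y) for positive y.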

Useful Quantities

This computes the side length of a box with the given number of dimensions, such that the given number of particles fills it at the given number density.

def box_size_at_number_density(particle_count: Int, density: Float, dim: Int) -> Float = pow ((i_to_f particle_count) / density) (1.0 / (i_to_f dim))
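
The formula is just the number density N / L^dim solved for L. A minimal Python sketch of the same arithmetic (an illustration only, not part of the Dex program):

```python
def box_size_at_number_density(particle_count: int, density: float, dim: int) -> float:
    # density = particle_count / side**dim, solved for side.
    return (particle_count / density) ** (1.0 / dim)

# 500 particles at number density 1.2 in two dimensions.
side = box_size_at_number_density(500, 1.2, 2)
print(side)
```

Plugging the result back in, side**2 * 1.2 recovers the particle count up to rounding.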

Spaces

def Displacement(dim|Ix) -> Type = (Vec dim, Vec dim) -> Vec dim
def Shift(dim|Ix) -> Type = (Vec dim, Vec dim) -> Vec dim
def free_displacement() ->> Displacement dim given (dim|Ix) = \r_1 r_2. r_1 - r_2
def free_shift() ->> Shift dim given (dim|Ix) = \r dr. r + dr
def periodic_wrap(box: Float, dr: dim=>Float) -> (dim=>Float) given (dim|Ix) = for i. (pmod (dr[i] + box / 2.0) box) - box / 2.0
def periodic_displacement(box:Float) -> Displacement dim given (dim|Ix) = \r_1 r_2. periodic_wrap box (r_1 - r_2)
def periodic_shift(box:Float) -> Shift dim given (dim|Ix) = \r dr. for i. pmod (r[i] + dr[i]) box
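
To see what the periodic versions do, here is a small Python sketch of the same three functions on plain lists (Python is used only for illustration; the names mirror the Dex definitions above). The wrapped displacement is the shortest one across the box boundary, and the shift keeps positions inside [0, box).

```python
def periodic_wrap(box, dr):
    # Map each component into [-box/2, box/2).
    return [((c + box / 2.0) % box) - box / 2.0 for c in dr]

def periodic_displacement(box, r1, r2):
    return periodic_wrap(box, [a - b for a, b in zip(r1, r2)])

def periodic_shift(box, r, dr):
    # Move a position and wrap it back into [0, box).
    return [(a + b) % box for a, b in zip(r, dr)]

# Two points near opposite edges of a box of side 10 are really 1 apart.
print(periodic_displacement(10.0, [9.5, 5.0], [0.5, 5.0]))
print(periodic_shift(10.0, [9.5, 5.0], [1.0, 0.0]))
```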

Energy Functions

def Energy(n|Ix, dim|Ix) -> Type = (n=>Vec dim) -> Float

The "soft-sphere" energy, as a function of the distance r between two particles.

def soft_sphere(ε: Float, α: Float, σ: Float, r: Float) -> Float =
  case r < σ of
    True -> ε / α * pow (1 - r / σ) α
    False -> 0.0
def harmonic_soft_sphere(sigma:Float) -> (Float) -> Float = \r. soft_sphere 1.0 2.0 sigma r
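
For intuition about the shape of this potential, here is a Python sketch of the same formula (illustrative only; parameter names are spelled out in ASCII). With α = 2 the energy is a half-parabola that reaches zero at r = σ and stays zero beyond it.

```python
def soft_sphere(eps, alpha, sigma, r):
    # Purely repulsive: nonzero only while the two spheres overlap (r < sigma).
    if r < sigma:
        return eps / alpha * (1.0 - r / sigma) ** alpha
    return 0.0

def harmonic_soft_sphere(sigma):
    return lambda r: soft_sphere(1.0, 2.0, sigma, r)

u = harmonic_soft_sphere(1.0)
print(u(0.0), u(0.5), u(1.0), u(2.0))
```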

Here we have a naive pairwise energy function constructor, that promotes a two-particle energy to a whole-system energy by just applying it to every pair of distinct particles. We start with this to have something to test with.

def pair_energy(
    energy: (Float) -> Float,
    displacement: Displacement dim,
    r: n=>Vec dim
    ) -> Float given (n|Ix, dim|Ix) =
  sum $ for i. sum for j:(..<i).
    energy $ norm $ displacement r[i] r[inject j]
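
The inner loop runs over j < i only, so each unordered pair is counted once and no particle is paired with itself. A rough Python equivalent (a sketch for illustration; `free_displacement` and `harmonic` are stand-ins defined here, not the Dex functions):

```python
import math

def pair_energy(energy, displacement, positions):
    total = 0.0
    for i in range(len(positions)):
        for j in range(i):  # j < i: every distinct pair exactly once
            d = displacement(positions[i], positions[j])
            total += energy(math.sqrt(sum(c * c for c in d)))
    return total

def free_displacement(r1, r2):
    return [a - b for a, b in zip(r1, r2)]

def harmonic(r):
    # Harmonic soft sphere with unit diameter.
    return 0.5 * (1.0 - r) ** 2 if r < 1.0 else 0.0

# Only the first two particles overlap, so only that pair contributes.
positions = [[0.0, 0.0], [0.5, 0.0], [3.0, 0.0]]
total = pair_energy(harmonic, free_displacement, positions)
print(total)
```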

Minimization Functions

Now we develop the FIRE Descent algorithm.

struct FireDescentState(n|Ix, dim|Ix) =
  R: (n=>Vec dim)
  V: (n=>Vec dim)
  F: (n=>Vec dim)
  dt: Float
  alpha: Float
  n_pos: Int
def fire_descent_init(
    dt: Float,
    alpha: Float,
    energy: Energy(n, dim),
    r: n=>Vec dim
    ) -> FireDescentState n dim given (n|Ix, dim|Ix) =
  force = \rp. -(grad energy) rp
  V = for i:n j:dim. 0.0
  F = force r
  n_pos: Int = 0
  FireDescentState(r, V, F, dt, alpha, n_pos)
def fire_descent_step(
    shift: Shift dim,
    energy: Energy n dim,
    state: FireDescentState n dim
    ) -> FireDescentState n dim given (n|Ix, dim|Ix) =
  -- Constants that parameterize the FIRE algorithm.
  -- TODO: Thread these constants through somehow.
  -- dougalm@ is there a canonical way to do this?
  dt_start = 0.1
  dt_max = 0.4
  n_min = 5
  f_inc = 1.1
  f_dec = 0.5
  f_alpha = 0.99
  alpha_start = 0.1
  ε = 0.000000001
  force = \r. -(grad energy) r
  -- FIRE algorithm.
  -- (MkFireDescentState {R, V, F, dt, alpha, n_pos}) = state
  dt = state.dt
  alpha = state.alpha
  n_pos = state.n_pos
  -- Do a Velocity-Verlet step.
  R = for i. shift state.R[i] (state.V[i] *. dt + state.F[i] *. pow dt 2)
  F_old = state.F
  F = force R
  V = state.V + dt * 0.5 .* (F_old + F)
  -- Rescale the velocity.
  F_norm = sqrt $ sum $ sum for i j. pow F[i, j] 2
  V_norm = sqrt $ sum $ sum for i j. pow V[i, j] 2
  V = V + alpha .* (F *. V_norm / (F_norm + ε) - V)
  -- Check whether the force is aligned with the velocity.
  FdotV = sum $ sum for i j. F[i, j] * V[i, j]
  -- Decide whether to increase the speed of the simulation or reduce it.
  (n_pos, dt, alpha) = if FdotV >= 0.0
    then
      dt' = if n_pos >= n_min then (min (dt * f_inc) dt_max) else dt
      alpha' = if n_pos >= n_min then (alpha * f_alpha) else alpha
      (n_pos + 1, dt', alpha')
    else (0, dt * f_dec, alpha_start)
  V = if FdotV >= 0.0 then V else zero
  FireDescentState(R, V, F, dt, alpha, n_pos)
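
The part of the step that makes FIRE adaptive is the last decision: speed up while the force keeps pointing along the velocity, and slow down and restart otherwise. Here is that rule in isolation as a Python sketch (illustrative only; `fire_adapt` is a name made up for this sketch, and its defaults copy the constants hard-coded in `fire_descent_step`):

```python
def fire_adapt(n_pos, dt, alpha, f_dot_v,
               dt_max=0.4, n_min=5, f_inc=1.1, f_dec=0.5,
               f_alpha=0.99, alpha_start=0.1):
    """Return (n_pos, dt, alpha, keep_velocity) for the next step."""
    if f_dot_v >= 0.0:
        # Moving downhill: after n_min good steps in a row, grow dt and shrink alpha.
        if n_pos >= n_min:
            dt = min(dt * f_inc, dt_max)
            alpha = alpha * f_alpha
        return n_pos + 1, dt, alpha, True
    # Moving uphill: cut the time step, reset alpha, and zero the velocity.
    return 0, dt * f_dec, alpha_start, False

print(fire_adapt(0, 0.1, 0.1, 1.0))
print(fire_adapt(7, 0.2, 0.05, -1.0))
```

The final flag corresponds to the `V = if FdotV >= 0.0 then V else zero` line in the Dex code.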

Drawing

import png
import diagram

Now a tool to draw a two-dimensional system, where each particle is a disk of given size.

TwoDimensions = Fin 2
def draw_system(radius, r: n=>Vec TwoDimensions) -> Diagram given (n|Ix) =
  disks = concat_diagrams for i.
    circle radius | move_xy(r[i, (0 @ TwoDimensions)], r[i, (1 @ TwoDimensions)])
  disks

Example

Initialize a system randomly.

N_small = 500
d = 2
L_small = box_size_at_number_density (n_to_i N_small) 1.2 (n_to_i d)
-- We will simulate in a box of this side length
L_small
20.41241
R_init_small = rand_mat N_small d (\k. L_small * rand k) (new_key 0)

The initial state of our random system

%time :html render_svg (draw_system 0.5 R_init_small) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 290.140 ms Run time: 26.617 ms

Define the energy function. Note the periodic_displacement, which means our system will be evolving on a torus.

def energy(pos: n=>d=>Float) -> Float given (n|Ix, d|Ix) = pair_energy (harmonic_soft_sphere 1.0) (periodic_displacement L_small) pos

Here's the initial energy we compute for our system.

:t energy R_init_small
Float32
energy R_init_small
74.69006

Initialize a simulation

state_small = fire_descent_init 0.1 0.1 energy R_init_small

and test one step of minimization. The energy decreases from its initial value, as expected:

energy $ fire_descent_step(free_shift, energy, state_small).R
71.78407

Now we can test that our code basically works by running 100 steps of minimization.

%time (state_small', energies) = scan state_small \i:(Fin 100) s.
  s' = fire_descent_step (periodic_shift L_small) energy s
  (s', energy $ s'.R)
Compile time: 729.278 ms Run time: 1.070 s

Here's how the energy decreases over time.

%time :html show_plot $ y_plot energies
Compile time: 522.516 ms Run time: 5.256 ms

Here's what the system looks like after minimization.

%time :html render_svg (draw_system 0.5 (state_small'.R)) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 306.320 ms Run time: 26.259 ms

Neighbors optimization

The above pair_energy function will compute the influence of every atom on every other atom, regardless of how far apart they are.

To simulate more efficiently, we'd like to approximate the pairwise energy with one that only includes contributions from pairs of atoms close enough to each other that they cannot be neglected.

This is a two-step operation:

  • Break the simulation volume into a grid of cells, and do a linear pass over the atoms to group them by which cell each is in.
  • Traverse every pair of adjacent cells and compute energy terms only for pairs of atoms drawn from those two cells.
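
Before the Dex version, here is the two-step idea as a compact Python sketch in two dimensions (an illustration under simplifying assumptions: a dictionary of lists stands in for the fixed-size buckets used below, the grid is assumed to be at least 3 cells wide so neighbor offsets do not collide, and `build_cells` and `candidate_pairs` are names made up for this sketch):

```python
from collections import defaultdict
from itertools import product

def build_cells(atoms, cell_size):
    """Step 1: one linear pass grouping atom indices by grid cell."""
    cells = defaultdict(list)
    for ix, (x, y) in enumerate(atoms):
        cells[(int(x // cell_size), int(y // cell_size))].append(ix)
    return cells

def candidate_pairs(cells, grid_size):
    """Step 2: ordered pairs of atoms from each cell and its 3x3 neighborhood."""
    pairs = []
    for (cx, cy), members in cells.items():
        for dx, dy in product((-1, 0, 1), repeat=2):
            # Wrap around the edges, since the box is periodic.
            neighbor = ((cx + dx) % grid_size, (cy + dy) % grid_size)
            for a in members:
                for b in cells.get(neighbor, ()):
                    pairs.append((a, b))
    return pairs

atoms = [(0.5, 0.5), (1.5, 0.5), (3.5, 0.5), (2.5, 2.5)]
cells = build_cells(atoms, cell_size=1.0)
pairs = candidate_pairs(cells, grid_size=4)
print(sorted(pairs))
```

Atoms 0 and 2 sit in cells on opposite edges of the grid and are still paired thanks to the wrap-around, while atom 3 is only ever paired with itself.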

Bounded lists

We start with an abstraction of an incrementally growable list. To get O(1) insertion at the end, we (currently) have to give an upper bound for the list's size.

-- TODO Can we encapsulate this BoundedList type as a `data` and still
-- define in-place mutation operations on it?
struct BoundedList(n|Ix, a|Data) =
  lim : n
  buf : n=>a
def unsafe_next_index(ix:n) -> n given (n|Ix) = unsafe_from_ordinal $ ordinal ix + 1
def empty_bounded_list(dummy_val: a) -> BoundedList n a given (a|Data, n|Ix) = BoundedList(unsafe_from_ordinal 0, for _. dummy_val)
-- The point of a `BoundedList` is O(1) push by in-place mutation
def bd_push(ref: Ref h (BoundedList n a), x: a) -> {State h} () given (h, a|Data, n|Ix) =
  sz = get ref.lim
  if ordinal sz < size n
    then
      ref.buf!sz := x
      ref.lim := unsafe_next_index sz
    else
      todo -- throw ()
-- Once we're done pushing, we can compact a `BoundedList` into a standard `List`.
def as_list(lst: BoundedList n a) -> List a given (a|Data, n|Ix) =
  n_result = ordinal lst.lim
  AsList _ $ for i:(Fin n_result). lst.buf[unsafe_from_ordinal $ ordinal i]
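
The same data structure in a language with ordinary mutation looks like the following Python sketch (illustrative only; raising `OverflowError` on a full buffer is an assumption of this sketch, whereas the Dex code above leaves that case as a `todo`):

```python
class BoundedList:
    """A fixed-capacity buffer plus a fill count, giving O(1) push."""

    def __init__(self, capacity, dummy_val):
        self.lim = 0                       # number of slots in use
        self.buf = [dummy_val] * capacity  # preallocated storage

    def push(self, x):
        if self.lim >= len(self.buf):
            raise OverflowError("BoundedList is full")
        self.buf[self.lim] = x
        self.lim += 1

    def as_list(self):
        # Compact into an ordinary list containing only the used prefix.
        return self.buf[:self.lim]

bl = BoundedList(3, 0)
bl.push(7)
bl.push(9)
print(bl.as_list())
```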

Cell list

We define a single index for the whole grid.

def GridIx(dim|Ix, grid_size) -> Type = dim => (Fin grid_size)

A cell list is now just a table that maps each cell in the grid to a BoundedList of the indices of the atoms in that cell.

def CellTable(dim|Ix, grid_size, bucket_size, atom_ix|Data) -> Type = GridIx dim grid_size => BoundedList (Fin bucket_size) atom_ix
-- Compute the cell an atom should go into.
def target_cell(cell_size: Float, atom: Vec dim) -> GridIx dim grid_size given (dim|Ix, grid_size) = for dim. from_ordinal $ f_to_n $ atom[dim] / cell_size
-- A traversal of an atom array together with target cell.
-- We abstract this because we will use it twice.
def traverse_cells(
    grid_size, cell_size,
    atoms: atom_ix => Vec dim,
    action: (atom_ix, GridIx(dim, grid_size)) -> {|eff} ()
    ) -> {|eff} () given (dim|Ix, atom_ix|Ix, eff) =
  for_ ix.
    cell = target_cell cell_size atoms[ix]
    action ix cell

Here is the actual cell table computation:

def cell_table(
    grid_size, bucket_size, cell_size,
    atoms: atom_ix => Vec dim
    ) -> CellTable(dim, grid_size, bucket_size, atom_ix) given (dim|Ix, atom_ix|Ix) =
  yield_state (for _. empty_bounded_list $ unsafe_from_ordinal 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      bd_push ref!cell ix
-- This is a helper for computing the bucket size. Right now we end
-- up traversing the input twice (once to compute a size and once to
-- actually build the cell table); this is working around limitations
-- in Dex's support for mutable lists of statically unknown length.
def cell_table_bucket_size(
    grid_size, cell_size,
    atoms: atom_ix => Vec dim
    ) -> Nat given (dim|Ix, atom_ix|Ix) =
  -- TODO Use yield_accum here
  bucket_sizes = yield_state (for _:(GridIx dim grid_size). 0) \ref.
    traverse_cells grid_size cell_size atoms \ix cell.
      ref!cell := get ref!cell + 1
  maximum bucket_sizes

Let's test it out on our "small" system:

def grid_params(L, desired_cell_size) =
  cells_per_side = floor(L / desired_cell_size)
  cell_size = L / cells_per_side
  grid_size = unsafe_i_to_n $ f_to_i cells_per_side
  (cell_size, grid_size)
desired_cell_size = 1.0 -- unit interaction range
(cell_size, grid_size) = grid_params L_small desired_cell_size
bucket_size = 10
%time tbl = cell_table grid_size bucket_size cell_size $ state_small'.R
Compile time: 138.761 ms Run time: 75.500 us

We have a table of cells with atoms in them

:t tbl
(((Fin 2) => Fin 20) => BoundedList (Fin 10) (Fin 500))

And here are the atoms in the 0th cell.

as_list tbl[unsafe_from_ordinal 0]
(AsList 2 [33, 237])

Now let's compute pairs of neighbors from our cell list

We'll specialize to two dimensions for now, but broadening to more is not difficult.

cell_neighbors_2d = [[-1, -1], [-1, 0], [-1, 1], [0, -1], [0, 0], [0, 1], [1, -1], [1, 0], [1, 1]]
:t cell_neighbors_2d
((Fin 9) => (Fin 2) => Int32)
-- Toroidal index offsetting
def torus_offset(ix: (Fin n), offset: Int) -> (Fin n) given (n) = unsafe_from_ordinal $ unsafe_i_to_n $ mod (n_to_i (ordinal ix) + offset) (n_to_i n)
-- A traversal of pairs of atoms from adjacent or equal cells in a
-- cell table.
def traverse_pairs(
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    atoms: atom_ix => Vec TwoDimensions,
    action: (atom_ix, atom_ix) -> {|eff} ()
    ) -> {|eff} () given (grid_size, bucket_size, atom_ix|Ix, eff) =
  for_ cell_ix : (GridIx TwoDimensions grid_size).
    for_ nb.
      displacement = cell_neighbors_2d[nb]
      neighbor_ix = for dim. torus_offset cell_ix[dim] displacement[dim]
      AsList(sz_atoms_1, atoms_1) = as_list tbl[cell_ix]
      for_ atom1_ix : (Fin sz_atoms_1).
        atom1 = atoms_1[atom1_ix]
        AsList(sz_atoms_2, atoms_2) = as_list tbl[neighbor_ix]
        for_ atom2_ix : (Fin sz_atoms_2).
          atom2 = atoms_2[atom2_ix]
          action atom1 atom2

The neighbor list computation. The point of the exercise is that this is not O(#atoms^2), but rather O(#cells) * 9 * O(#atoms per cell^2), because it only considers atoms from adjacent cells as potential neighbors.

def neighbor_list(
    neighbor_list_size,
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    is_neighbor: (atom_ix, atom_ix) -> Bool,
    atoms: atom_ix => Vec TwoDimensions
    ) -> BoundedList(Fin neighbor_list_size, (atom_ix, atom_ix)) given (grid_size, bucket_size, atom_ix|Ix) =
  yield_state(empty_bounded_list((from_ordinal 0, from_ordinal 0))) \ref.
    traverse_pairs(tbl, atoms) \atom1 atom2.
      if is_neighbor(atom1, atom2) then
        bd_push(ref, (atom1, atom2))
-- Also a helper to pre-compute how large a buffer to allocate for the
-- neighbor list, by traversing the input an extra time.
def count_neighbor_list_size(
    tbl: CellTable(TwoDimensions, grid_size, bucket_size, atom_ix),
    is_neighbor: (atom_ix, atom_ix) -> Bool,
    atoms: atom_ix => Vec TwoDimensions
    ) -> Nat given (grid_size, bucket_size, atom_ix|Ix) =
  yield_accum (AddMonoid Nat) \ref.
    traverse_pairs tbl atoms (\atom1 atom2. if is_neighbor atom1 atom2 then ref += 1)
def periodic_near(
    desired_cell_size, L,
    atoms: atom_ix => Vec TwoDimensions
    ) -> (atom_ix, atom_ix) -> Bool given (atom_ix|Ix) =
  \atom1 atom2.
    dist = norm $ periodic_displacement(L)(atoms[atom1], atoms[atom2])
    dist < desired_cell_size
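
To get a feel for the cost claim, here is a back-of-the-envelope comparison as a Python sketch (illustrative only; it assumes the atoms are spread uniformly over the cells, and both function names are made up for this sketch). The naive sum visits every unordered pair, while the cell traversal visits roughly #cells * 9 * (#atoms per cell)^2 ordered candidate pairs.

```python
def all_pairs(n_atoms: int) -> int:
    """Number of unordered pairs visited by the naive pairwise sum."""
    return n_atoms * (n_atoms - 1) // 2

def cell_candidates(n_atoms: int, n_cells: int) -> float:
    """Rough count of ordered candidate pairs: cells * 9 * (atoms per cell)^2."""
    per_cell = n_atoms / n_cells
    return n_cells * 9 * per_cell ** 2

# The small system: 500 atoms on a 20 x 20 grid.
print(all_pairs(500), cell_candidates(500, 20 * 20))
```

For the small system this is 124750 pairs against about 5625 candidates. At fixed density the number of cells grows with the number of atoms, so the second figure grows linearly rather than quadratically.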

Let's test this one out on our "small" system.

%time res = (neighbor_list 4000 tbl (periodic_near 1.0 L_small $ state_small'.R) $ state_small'.R)
Compile time: 153.059 ms Run time: 1.346 ms

In that configuration, we find this many pairs of neighbors:

AsList(k, _) = as_list res
k
3090

Now that we have the concept of neighbor lists, we can define a variant of pair_energy that only considers atoms that the neighbor list says are close.

def pair_energy_nl(
    energy: (Float) -> Float,
    displacement: Displacement dim,
    r: n=>Vec dim,
    neighbors: List (n, n)
    ) -> Float given (dim|Ix, n|Ix) =
  AsList(k, nbrs) = neighbors
  sum for i.
    (a1ix, a2ix) = nbrs[i]
    case (ordinal a1ix) < (ordinal a2ix) of
      True -> energy $ norm $ displacement r[a1ix] r[a2ix]
      False -> 0.0
def energy_nl(
    L,
    neighbors: List (n, n)
    ) -> (n=>d=>Float) -> Float given (d|Ix, n|Ix) =
  \pos. pair_energy_nl (harmonic_soft_sphere 1.0) (periodic_displacement L) pos neighbors

And here's the check that it computes near-identical results to our original, fully pairwise energy function.

energy_nl(L_small, as_list res)(state_small'.R)
1.230771
energy (state_small'.R)
1.230771
-- Package the above up into a function that just computes the
-- neighbor list from an array of atoms.
def just_neighbor_list(
    desired_cell_size, L,
    atoms:n=>(Fin 2)=>Float
    ) -> (List (n, n)) given (n|Ix) =
  (cell_size, grid_size) = grid_params L desired_cell_size
  bucket_size = cell_table_bucket_size grid_size cell_size atoms
  tbl = cell_table grid_size bucket_size cell_size atoms
  is_neighbor = periodic_near desired_cell_size L atoms
  neighbor_list_sz = count_neighbor_list_size tbl is_neighbor atoms
  res = neighbor_list neighbor_list_sz tbl is_neighbor atoms
  as_list res

The neighbor list energy function works as an argument to our FIRE Descent algorithm (provided the neighbor list uses the same atoms).

state_nl =
  energy_func = (energy_nl L_small $ just_neighbor_list 1.0 L_small R_init_small)
  fire_descent_init 0.1 0.1 energy_func R_init_small
energy R_init_small
74.69006
energy $ fire_descent_step(free_shift, energy, state_small).R
71.78407
-- A helper for short-circuiting `any` computation
def fast_any(f: (n) -> {|eff} Bool) -> {|eff} Bool given (n|Ix, eff) =
  iter \ct.
    if ct >= size n
      then Done False
      else if f (ct @ n)
        then Done True
        else Continue

And now that this basically works, we can package the whole thing up as a simulation. We have another trick here: we compute the neighbor list with an extra "halo", treating atoms as neighbors if they are within distance 1 + halo of each other, rather than just the interaction range 1. This way, we only have to recompute the neighbor list when some atom has moved more than halo/2 away from where it was when the neighbor list was computed, because until then the list remains a sound approximation.
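
The reason halo/2 is the right threshold: if no atom has moved more than halo/2, then no pair has closed in by more than halo, so any pair that was farther apart than 1 + halo when the list was built is still outside the interaction range 1. The rebuild test can be sketched in Python as follows (illustrative only; `needs_rebuild` is a name made up for this sketch, and a plain non-periodic displacement is used to keep it short):

```python
import math

def norm(v):
    return math.sqrt(sum(c * c for c in v))

def free_displacement(r1, r2):
    return [a - b for a, b in zip(r1, r2)]

def needs_rebuild(saved_atoms, new_atoms, displacement, halo_size):
    # Rebuild as soon as any single atom has moved more than halo_size / 2.
    # any() stops at the first offender, like fast_any in the Dex code.
    return any(2.0 * norm(displacement(old, new)) > halo_size
               for old, new in zip(saved_atoms, new_atoms))

saved = [[0.0, 0.0], [5.0, 5.0]]
print(needs_rebuild(saved, [[0.2, 0.0], [5.0, 5.0]], free_displacement, 0.5))
print(needs_rebuild(saved, [[0.2, 0.0], [5.3, 5.0]], free_displacement, 0.5))
```

With a halo of 0.5, a move of 0.2 is tolerated and a move of 0.3 triggers a rebuild.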

-- TODO(Issue 1133) Can't use scan with a body that has an effect?
def simulate(
    displacement : Displacement TwoDimensions,
    halo_size, L,
    time|Ix,
    state : FireDescentState(atom_ix, TwoDimensions)
    ) -> {IO} (FireDescentState(atom_ix, TwoDimensions), time => Float) given (atom_ix|Ix) =
  with_state state.R \saved_atoms_ref.
    nbrs = just_neighbor_list (1.0 + halo_size) L $ state.R
    AsList(k, _) = nbrs
    print $ show k <> " initial neighbor list size"
    with_state nbrs \saved_neighbors_ref.
      swap $ run_state state \s_ref. for i.
        s = get s_ref
        new_atoms = s.R
        rebuild = fast_any \i. 2 * norm (displacement (get saved_atoms_ref!i) new_atoms[i]) > halo_size
        if rebuild then
          saved_atoms_ref := new_atoms
          nbrs = just_neighbor_list (1.0 + halo_size) L new_atoms
          saved_neighbors_ref := nbrs
          AsList(k, _) = nbrs
          print $ show k <> " new neighbor list size"
        nbrs = get saved_neighbors_ref
        s' = fire_descent_step (periodic_shift L) (energy_nl L nbrs) s
        s_ref := s'
        energy_nl(L, nbrs)(s'.R)

Let's check that it works on our test system from before

%time (state_nl', energies_nl) = unsafe_io \. simulate (periodic_displacement L_small) 0.5 L_small (Fin 100) state_small
4568 initial neighbor list size
4614 new neighbor list size
4564 new neighbor list size
4376 new neighbor list size
4346 new neighbor list size
4266 new neighbor list size
4178 new neighbor list size
4140 new neighbor list size
4100 new neighbor list size
4028 new neighbor list size
4006 new neighbor list size
3922 new neighbor list size
3868 new neighbor list size
3810 new neighbor list size
3762 new neighbor list size
3720 new neighbor list size
3656 new neighbor list size
3640 new neighbor list size
3602 new neighbor list size
3572 new neighbor list size
3552 new neighbor list size
Compile time: 2.060 s Run time: 69.556 ms
%time :html show_plot $ y_plot energies_nl
Compile time: 476.051 ms Run time: 4.799 ms
%time :html render_svg (draw_system 0.5 state_nl'.R) (Point(0.0, 0.0), Point(L_small, L_small))
Compile time: 263.806 ms Run time: 27.630 ms

But of course the point of the exercise is that this now scales up to larger systems because it avoids the quadratic energy computation.

N_large = if not (dex_test_mode()) then 50000 else 500
L_large = box_size_at_number_density (n_to_i N_large) 1.2 (n_to_i d)
L_large
204.1241
R_init_large = rand_mat N_large d (\k. L_large * rand k) (new_key 0)

Initial state (we render the atoms smaller now so they don't over-plot too badly).

%time :html render_svg (draw_system 0.2 R_init_large) (Point(0.0, 0.0), Point(L_large, L_large))
Compile time: 776.498 ms Run time: 2.603 s
state_large =
  energy_func = (energy_nl L_large $ just_neighbor_list 1.0 L_large R_init_large)
  fire_descent_init 0.1 0.1 energy_func R_init_large
%time (state_large_nl', energies_large_nl) = unsafe_io \. simulate (periodic_displacement L_large) 0.5 L_large (Fin 100) state_large
474778 initial neighbor list size
474090 new neighbor list size
471628 new neighbor list size
465540 new neighbor list size
453724 new neighbor list size
442298 new neighbor list size
434298 new neighbor list size
428432 new neighbor list size
421932 new neighbor list size
418256 new neighbor list size
414656 new neighbor list size
411276 new neighbor list size
404290 new neighbor list size
398562 new neighbor list size
393958 new neighbor list size
390030 new neighbor list size
386782 new neighbor list size
382868 new neighbor list size
379510 new neighbor list size
376428 new neighbor list size
374024 new neighbor list size
371272 new neighbor list size
368946 new neighbor list size
366958 new neighbor list size
365340 new neighbor list size
363842 new neighbor list size
362166 new neighbor list size
360926 new neighbor list size
360118 new neighbor list size
359558 new neighbor list size
358906 new neighbor list size
358038 new neighbor list size
357402 new neighbor list size
357078 new neighbor list size
Compile time: 1.797 s Run time: 13.124 s

Energy decrease

%time :html show_plot $ y_plot energies_large_nl
Compile time: 507.910 ms Run time: 3.984 ms

System state after minimization.

%time :html render_svg (draw_system 0.2 state_large_nl'.R) (Point(0.0, 0.0), Point(L_large, L_large))
Compile time: 1.011 s Run time: 2.928 s